I’ve been reviewing the scikit-learn (scikit for short) library for several months, so I figured I’d do a multi-class decision tree classification example. Before I go any further, let me comment that machine learning beginners are often seduced by the visual elegance of decision trees, but tree classifiers have several weaknesses, notably a strong tendency to overfit and a high sensitivity to small changes in the training data.
I used one of my standard datasets for multi-class classification. The data looks like:
```
1 0.24 1 0 0 0.2950 2
0 0.39 0 0 1 0.5120 0
1 0.63 0 1 0 0.7580 1
0 0.36 1 0 0 0.4450 2
. . .
```
Each line of data represents a person. The fields are sex (male = 0, female = 1), age (normalized by dividing by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), annual income (divided by 100,000), and politics type (conservative = 0, moderate = 1, liberal = 2). The goal is to predict the politics type of a person from their sex, age, State, and income.
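To make the encoding concrete, here is a minimal sketch that turns a raw person record into the six numeric predictors described above. The helper names `encode_person` and `decode_politics` are hypothetical and are not part of the original demo; the sketch assumes sex is given as "M" or "F", age in years, and income in dollars.

```python
# Hypothetical helpers (not from the original demo) that apply the
# encoding described above to a single raw record.
STATES = ["michigan", "nebraska", "oklahoma"]        # one-hot 100, 010, 001
POLITICS = ["conservative", "moderate", "liberal"]   # class labels 0, 1, 2

def encode_person(sex, age, state, income):
    """sex is 'M' or 'F', age in years, state name, income in dollars."""
    sex_val = 0 if sex == "M" else 1           # male = 0, female = 1
    one_hot = [0, 0, 0]
    one_hot[STATES.index(state.lower())] = 1   # ValueError for unknown states
    return [sex_val, age / 100] + one_hot + [income / 100_000]

def decode_politics(label):
    """Map a predicted class label (0, 1, 2) back to its name."""
    return POLITICS[label]

print(encode_person("M", 35, "Nebraska", 55000))  # [0, 0.35, 0, 1, 0, 0.55]
print(decode_politics(1))                         # moderate
```

The first call produces the same vector the demo program uses for its "M 35 Nebraska $55K" prediction.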
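It isn’t necessary to normalize age and income for a decision tree, because each split compares a single predictor against a threshold, so the scale of a predictor doesn't affect which items go left or right. Converting categorical predictors like State is conceptually tricky, but the bottom line is that in most scenarios it’s best to one-hot encode. For binary predictor variables, I recommend using 0 or 1 encoding, but again, there are a lot of subtle details.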
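The scale-invariance claim can be checked with a small sketch. It finds the best single split on one feature using Gini impurity (the default criterion of scikit's DecisionTreeClassifier), once on raw ages and once on ages divided by 100. This is a simplified illustration, not scikit's actual splitter; the ages and labels are the first eight training items.

```python
# Sketch: best single-feature split by Gini impurity. Rescaling the
# feature by a positive constant moves the threshold, not the partition.

def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for y in labels:
        counts[y] = counts.get(y, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

def best_split(values, labels):
    """Return (threshold, frozenset of row indices sent left)."""
    n = len(labels)
    distinct = sorted(set(values))
    best = None
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2  # candidate threshold halfway between neighbors
        left = [i for i, v in enumerate(values) if v <= t]
        right = [i for i, v in enumerate(values) if v > t]
        # weighted average impurity of the two children
        score = (len(left) * gini([labels[i] for i in left]) +
                 len(right) * gini([labels[i] for i in right])) / n
        if best is None or score < best[0]:
            best = (score, t, frozenset(left))
    return best[1], best[2]

ages_raw = [24, 39, 63, 36, 27, 50, 50, 19]   # years
politics = [2, 1, 0, 1, 2, 1, 1, 0]
ages_scaled = [a / 100 for a in ages_raw]     # normalized by dividing by 100

t_raw, left_raw = best_split(ages_raw, politics)
t_scaled, left_scaled = best_split(ages_scaled, politics)
print(t_raw, t_scaled)          # 31.5 and approximately 0.315
print(left_raw == left_scaled)  # True: the same rows go left either way
```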
The key lines of code are:
```python
import numpy as np
from sklearn import tree

md = 3  # max depth
print("Creating decision tree max_depth=" + str(md))
model = tree.DecisionTreeClassifier(max_depth=md)
model.fit(train_x, train_y)
print("Done ")
```
Decision trees are highly prone to overfitting. If you set a large max_depth, you can get 100% classification accuracy on the training data, but the accuracy on test data and on new, previously unseen data will likely be very poor.
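Here is a minimal sketch of that effect. It uses randomly generated placeholder data (not the people data) whose labels have no relationship to the predictors, so any training accuracy above chance is pure memorization. With max_depth=None the tree grows until every leaf is pure.

```python
# Sketch: overfitting on signal-free random data.
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
train_x = rng.random((200, 6))
train_y = rng.integers(0, 3, size=200)   # random labels, no real signal
test_x = rng.random((200, 6))
test_y = rng.integers(0, 3, size=200)

results = {}
for md in [1, 3, None]:                  # None = grow until leaves are pure
    model = DecisionTreeClassifier(max_depth=md, random_state=0)
    model.fit(train_x, train_y)
    results[md] = (model.score(train_x, train_y),
                   model.score(test_x, test_y))
    print("max_depth=%s  train acc=%0.4f  test acc=%0.4f"
          % (md, results[md][0], results[md][1]))
```

The unlimited-depth tree reaches 100% training accuracy, while its test accuracy stays near the one-in-three chance level.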
The accuracy of the model can be displayed using the built-in score() method or indirectly in the form of a confusion matrix:
```python
from sklearn.metrics import confusion_matrix

y_predicteds = model.predict(test_x)
cm = confusion_matrix(test_y, y_predicteds)
print("Confusion matrix: \n")
# print(cm)  # no formatting
show_confusion(cm)  # custom formatting
```
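The two views of accuracy are directly related: in a confusion matrix with actual classes as rows and predicted classes as columns (the layout scikit's confusion_matrix uses), correct predictions lie on the diagonal, so accuracy is the trace divided by the total count. A minimal NumPy sketch with made-up labels, using a hypothetical `confusion` helper rather than the scikit function:

```python
import numpy as np

def confusion(actual, predicted, n_classes):
    """Hypothetical helper: rows are actual classes, columns predicted."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for a, p in zip(actual, predicted):
        cm[a, p] += 1
    return cm

actual    = [0, 1, 2, 2, 1, 0, 1, 2]   # made-up labels for illustration
predicted = [0, 1, 1, 2, 1, 2, 1, 2]
cm = confusion(actual, predicted, 3)
print(cm)

# correct predictions are on the diagonal
acc = np.trace(cm) / np.sum(cm)
print(acc)  # 0.75, the same value score() would report for these labels
```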
There are several ways to visualize a trained tree classifier. The model can be displayed as text pseudo-code like this:
```python
# newer scikit versions require feature_names as a keyword argument
pseudo = tree.export_text(model, feature_names=["sex", "age",
    "state0", "state1", "state2", "income"])
print("Model in pseudo-code: ")
print(pseudo)
```
The wackiness of the pseudo-code points out a weakness of decision trees: they’re highly sensitive to small changes in the training data.
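Both the pseudo-code and the custom tree_to_pseudo() function in the demo program below read the fitted model's tree_ arrays (children_left, children_right, feature, threshold, value). As a sketch of what those arrays mean, the hypothetical `predict_one` function below classifies one item by walking from the root to a leaf, then checks that it agrees with predict(). The data is random placeholder data; it is cast to float32 because scikit trees compare float32 inputs against their thresholds.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def predict_one(model, x):
    """Hypothetical helper: walk the fitted tree_ arrays for one item."""
    t = model.tree_
    node = 0                                   # start at the root
    while t.children_left[node] != -1:         # -1 marks a leaf
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    # the leaf's value holds per-class counts or fractions
    return model.classes_[np.argmax(t.value[node])]

rng = np.random.default_rng(1)
x = rng.random((100, 6)).astype(np.float32)    # placeholder data
y = rng.integers(0, 3, size=100)
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(x, y)

manual = np.array([predict_one(model, row) for row in x])
print(np.array_equal(manual, model.predict(x)))
```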
The tree model can be displayed graphically using the plot_tree() function like so:
```python
import matplotlib.pyplot as plt

plt.figure(figsize=(14,8), tight_layout=True)  # w,h inches
tree.plot_tree(model, feature_names=["sex", "age",
    "state0", "state1", "state2", "income"],
    class_names=["con", "mod", "lib"], fontsize=8)
plt.show()
```
Anyway, the demo was a good refresher for me.
One of my favorite series of science fiction books is the Mars series by author Edgar Rice Burroughs. The fictional world has a lot of politics and races: Red Martians (human-like), Green Martians (fierce, 15 feet tall with six limbs), Yellow Martians (secretive), White Martians (predecessors to Red Martians), and Black Martians (evil).
“A Fighting Man of Mars” is the seventh book in the series. It was first published in book form in 1931. The book tells the story of low-born soldier Tan Hadron, who sets off to rescue snooty noblewoman Sanoma. He has many adventures and rescues, and falls in love with beautiful slave Tavia, who turns out to be a princess.
Left: Cover art by Robert Abbett. Center: Cover art by Michael Whelan. Right: Cover art by Roy Krenkel.
Demo code. Replace “lte” with the Boolean less-than-or-equal operator (<=). The data is also listed below.
```python
# people_politics_tree_sckit.py
# predict politics (0 = con, 1 = mod, 2 = lib)
# from sex, age, state, income

# sex age state income politics
# 0 0.27 0 1 0 0.7610 2
# 1 0.19 0 0 1 0.6550 0
# sex: 0 = male, 1 = female
# state: michigan = 100, nebraska = 010, oklahoma = 001
# politics: conservative, moderate, liberal

# Anaconda3-2020.02 Python 3.7.6 scikit 0.22.1
# Windows 10/11

import numpy as np
from sklearn import tree

# ---------------------------------------------------------

def tree_to_pseudo(model, feature_names):
    # custom function to display tree model pseudo-code
    left = model.tree_.children_left
    right = model.tree_.children_right
    threshold = model.tree_.threshold
    # leaf nodes have feature index -2; those names are never printed
    features = [feature_names[i] for i in model.tree_.feature]
    value = model.tree_.value

    def recurse(left, right, threshold, features, node, depth=0):
        indent = " " * depth
        if (threshold[node] != -2):  # -2 marks a leaf node
            v = "%0.4f" % threshold[node]
            print(indent,"if ( " + features[node] + " lte " + str(v) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node], depth+1)
            print(indent,"} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node], depth+1)
            print(indent,"}")
        else:
            idx = np.argmax(value[node])
            # print(indent,"return " + str(value[node]))
            print(indent,"return " + str(model.classes_[idx]))

    recurse(left, right, threshold, features, 0)

# ---------------------------------------------------------

def show_confusion(cm):
    dim = len(cm)
    mx = np.max(cm)             # largest count in cm
    wid = len(str(mx)) + 1      # width to print
    fmt = "%" + str(wid) + "d"  # like "%3d"
    for i in range(dim):
        print("actual ", end="")
        print("%3d:" % i, end="")
        for j in range(dim):
            print(fmt % cm[i][j], end="")
        print("")
    print("------------")
    print("predicted ", end="")
    for j in range(dim):
        print(fmt % j, end="")
    print("")

# ---------------------------------------------------------

def main():
    # 0. get ready
    print("\nBegin scikit decision tree example ")
    print("Predict politics from sex, age, State, income ")
    np.random.seed(0)
    np.set_printoptions(precision=4, suppress=True)

    # sex age state income politics
    # 0 0.27 0 1 0 0.7610 2
    # 1 0.19 0 0 1 0.6550 0

    # 1. load data
    print("\nLoading data into memory ")
    train_file = ".\\Data\\people_train.txt"
    train_xy = np.loadtxt(train_file, usecols=range(0,7),
        delimiter="\t", comments="#", dtype=np.float32)
    train_x = train_xy[:,0:6]
    train_y = train_xy[:,6].astype(int)

    test_file = ".\\Data\\people_test.txt"
    test_xy = np.loadtxt(test_file, usecols=range(0,7),
        delimiter="\t", comments="#", dtype=np.float32)
    test_x = test_xy[:,0:6]
    test_y = test_xy[:,6].astype(int)

    print("\nTraining data:")
    print(train_x[0:4])
    print(". . . \n")
    print(train_y[0:4])
    print(". . . ")

    # 2. create and train
    md = 3
    print("\nCreating decision tree max_depth=" + str(md))
    model = tree.DecisionTreeClassifier(max_depth=md)
    model.fit(train_x, train_y)
    print("Done ")

    # 3. evaluate
    acc_train = model.score(train_x, train_y)
    print("\nAccuracy on train = %0.4f " % acc_train)
    acc_test = model.score(test_x, test_y)
    print("Accuracy on test = %0.4f " % acc_test)

    # 3b. display formatted confusion matrix
    from sklearn.metrics import confusion_matrix
    y_predicteds = model.predict(test_x)
    cm = confusion_matrix(test_y, y_predicteds)
    print("\nConfusion matrix: \n")
    show_confusion(cm)

    # 4a. visualize using custom function
    # print("\nModel in pseudo-code: ")
    # tree_to_pseudo(model, ["sex", "age",
    #     "state0", "state1", "state2",
    #     "income"])

    # 4b. use built-in export_text()
    pseudo = tree.export_text(model, feature_names=["sex", "age",
        "state0", "state1", "state2", "income"])
    print("\nModel in pseudo-code: ")
    print(pseudo)

    # 4c. use built-in plot_tree()
    import matplotlib.pyplot as plt
    plt.figure(figsize=(14,8), tight_layout=True)  # w,h inches
    tree.plot_tree(model, feature_names=["sex", "age",
        "state0", "state1", "state2", "income"],
        class_names=["con", "mod", "lib"], fontsize=8)
    plt.show()

    # 5. use model
    print("\nPredict for: M 35 Nebraska $55K ")
    X = np.array([[0, 0.35, 0,1,0, 0.5500]],
        dtype=np.float32)
    probs = model.predict_proba(X)
    print("\nPrediction pseudo-probs: ")
    print(probs)

    politic = model.predict(X)
    print("\nPredicted class: ")
    print(politic)

    # 6. TODO: save model using pickle
    # import pickle
    # print("Saving trained tree model ")
    # path = ".\\Models\\tree_scikit_model.sav"
    # pickle.dump(model, open(path, "wb"))

    # use saved model
    # X = np.array([[0, 0.35, 0,1,0, 0.5500]],
    #     dtype=np.float32)
    # with open(path, 'rb') as f:
    #     loaded_model = pickle.load(f)
    # pa = loaded_model.predict_proba(X)
    # print(pa)

    print("\nEnd scikit decision tree demo ")

if __name__ == "__main__":
    main()
```
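The save/load step is left as a TODO in the demo. Here is a minimal sketch of one way to do it with pickle, trained on random placeholder data so it runs on its own, and writing to a temporary directory instead of the demo's .\Models path. Using `with` blocks ensures the file is closed after writing, which the commented-out `pickle.dump(model, open(path, "wb"))` line doesn't guarantee.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(2)
train_x = rng.random((50, 6)).astype(np.float32)   # placeholder data
train_y = rng.integers(0, 3, size=50)
model = DecisionTreeClassifier(max_depth=3).fit(train_x, train_y)

X = np.array([[0, 0.35, 0, 1, 0, 0.5500]], dtype=np.float32)  # M 35 Nebraska $55K

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tree_scikit_model.sav")
    with open(path, "wb") as f:
        pickle.dump(model, f)          # save trained model
    with open(path, "rb") as f:
        loaded_model = pickle.load(f)  # reload it

print(loaded_model.predict_proba(X))
```

The scikit documentation cautions that a pickled model may not load correctly under a different scikit version, so it's safest to reload with the same version that saved it.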
Training data. Replace commas with tab characters or modify program.
```
# people_train.txt
# sex (M=0, F=1), age (div 100)
# state (michigan = 100, nebraska = 010,
#   oklahoma = 001)
# income (div 100,000)
# politics (con = 0, mod = 1, lib = 2)
#
1,0.24,1,0,0,0.2950,2
0,0.39,0,0,1,0.5120,1
1,0.63,0,1,0,0.7580,0
0,0.36,1,0,0,0.4450,1
1,0.27,0,1,0,0.2860,2
1,0.50,0,1,0,0.5650,1
1,0.50,0,0,1,0.5500,1
0,0.19,0,0,1,0.3270,0
1,0.22,0,1,0,0.2770,1
0,0.39,0,0,1,0.4710,2
1,0.34,1,0,0,0.3940,1
0,0.22,1,0,0,0.3350,0
1,0.35,0,0,1,0.3520,2
0,0.33,0,1,0,0.4640,1
1,0.45,0,1,0,0.5410,1
1,0.42,0,1,0,0.5070,1
0,0.33,0,1,0,0.4680,1
1,0.25,0,0,1,0.3000,1
0,0.31,0,1,0,0.4640,0
1,0.27,1,0,0,0.3250,2
1,0.48,1,0,0,0.5400,1
0,0.64,0,1,0,0.7130,2
1,0.61,0,1,0,0.7240,0
1,0.54,0,0,1,0.6100,0
1,0.29,1,0,0,0.3630,0
1,0.50,0,0,1,0.5500,1
1,0.55,0,0,1,0.6250,0
1,0.40,1,0,0,0.5240,0
1,0.22,1,0,0,0.2360,2
1,0.68,0,1,0,0.7840,0
0,0.60,1,0,0,0.7170,2
0,0.34,0,0,1,0.4650,1
0,0.25,0,0,1,0.3710,0
0,0.31,0,1,0,0.4890,1
1,0.43,0,0,1,0.4800,1
1,0.58,0,1,0,0.6540,2
0,0.55,0,1,0,0.6070,2
0,0.43,0,1,0,0.5110,1
0,0.43,0,0,1,0.5320,1
0,0.21,1,0,0,0.3720,0
1,0.55,0,0,1,0.6460,0
1,0.64,0,1,0,0.7480,0
0,0.41,1,0,0,0.5880,1
1,0.64,0,0,1,0.7270,0
0,0.56,0,0,1,0.6660,2
1,0.31,0,0,1,0.3600,1
0,0.65,0,0,1,0.7010,2
1,0.55,0,0,1,0.6430,0
0,0.25,1,0,0,0.4030,0
1,0.46,0,0,1,0.5100,1
0,0.36,1,0,0,0.5350,0
1,0.52,0,1,0,0.5810,1
1,0.61,0,0,1,0.6790,0
1,0.57,0,0,1,0.6570,0
0,0.46,0,1,0,0.5260,1
0,0.62,1,0,0,0.6680,2
1,0.55,0,0,1,0.6270,0
0,0.22,0,0,1,0.2770,1
0,0.50,1,0,0,0.6290,0
0,0.32,0,1,0,0.4180,1
0,0.21,0,0,1,0.3560,0
1,0.44,0,1,0,0.5200,1
1,0.46,0,1,0,0.5170,1
1,0.62,0,1,0,0.6970,0
1,0.57,0,1,0,0.6640,0
0,0.67,0,0,1,0.7580,2
1,0.29,1,0,0,0.3430,2
1,0.53,1,0,0,0.6010,0
0,0.44,1,0,0,0.5480,1
1,0.46,0,1,0,0.5230,1
0,0.20,0,1,0,0.3010,1
0,0.38,1,0,0,0.5350,1
1,0.50,0,1,0,0.5860,1
1,0.33,0,1,0,0.4250,1
0,0.33,0,1,0,0.3930,1
1,0.26,0,1,0,0.4040,0
1,0.58,1,0,0,0.7070,0
1,0.43,0,0,1,0.4800,1
0,0.46,1,0,0,0.6440,0
1,0.60,1,0,0,0.7170,0
0,0.42,1,0,0,0.4890,1
0,0.56,0,0,1,0.5640,2
0,0.62,0,1,0,0.6630,2
0,0.50,1,0,0,0.6480,1
1,0.47,0,0,1,0.5200,1
0,0.67,0,1,0,0.8040,2
0,0.40,0,0,1,0.5040,1
1,0.42,0,1,0,0.4840,1
1,0.64,1,0,0,0.7200,0
0,0.47,1,0,0,0.5870,2
1,0.45,0,1,0,0.5280,1
0,0.25,0,0,1,0.4090,0
1,0.38,1,0,0,0.4840,0
1,0.55,0,0,1,0.6000,1
0,0.44,1,0,0,0.6060,1
1,0.33,1,0,0,0.4100,1
1,0.34,0,0,1,0.3900,1
1,0.27,0,1,0,0.3370,2
1,0.32,0,1,0,0.4070,1
1,0.42,0,0,1,0.4700,1
0,0.24,0,0,1,0.4030,0
1,0.42,0,1,0,0.5030,1
1,0.25,0,0,1,0.2800,2
1,0.51,0,1,0,0.5800,1
0,0.55,0,1,0,0.6350,2
1,0.44,1,0,0,0.4780,2
0,0.18,1,0,0,0.3980,0
0,0.67,0,1,0,0.7160,2
1,0.45,0,0,1,0.5000,1
1,0.48,1,0,0,0.5580,1
0,0.25,0,1,0,0.3900,1
0,0.67,1,0,0,0.7830,1
1,0.37,0,0,1,0.4200,1
0,0.32,1,0,0,0.4270,1
1,0.48,1,0,0,0.5700,1
0,0.66,0,0,1,0.7500,2
1,0.61,1,0,0,0.7000,0
0,0.58,0,0,1,0.6890,1
1,0.19,1,0,0,0.2400,2
1,0.38,0,0,1,0.4300,1
0,0.27,1,0,0,0.3640,1
1,0.42,1,0,0,0.4800,1
1,0.60,1,0,0,0.7130,0
0,0.27,0,0,1,0.3480,0
1,0.29,0,1,0,0.3710,0
0,0.43,1,0,0,0.5670,1
1,0.48,1,0,0,0.5670,1
1,0.27,0,0,1,0.2940,2
0,0.44,1,0,0,0.5520,0
1,0.23,0,1,0,0.2630,2
0,0.36,0,1,0,0.5300,2
1,0.64,0,0,1,0.7250,0
1,0.29,0,0,1,0.3000,2
0,0.33,1,0,0,0.4930,1
0,0.66,0,1,0,0.7500,2
0,0.21,0,0,1,0.3430,0
1,0.27,1,0,0,0.3270,2
1,0.29,1,0,0,0.3180,2
0,0.31,1,0,0,0.4860,1
1,0.36,0,0,1,0.4100,1
1,0.49,0,1,0,0.5570,1
0,0.28,1,0,0,0.3840,0
0,0.43,0,0,1,0.5660,1
0,0.46,0,1,0,0.5880,1
1,0.57,1,0,0,0.6980,0
0,0.52,0,0,1,0.5940,1
0,0.31,0,0,1,0.4350,1
0,0.55,1,0,0,0.6200,2
1,0.50,1,0,0,0.5640,1
1,0.48,0,1,0,0.5590,1
0,0.22,0,0,1,0.3450,0
1,0.59,0,0,1,0.6670,0
1,0.34,1,0,0,0.4280,2
0,0.64,1,0,0,0.7720,2
1,0.29,0,0,1,0.3350,2
0,0.34,0,1,0,0.4320,1
0,0.61,1,0,0,0.7500,2
1,0.64,0,0,1,0.7110,0
0,0.29,1,0,0,0.4130,0
1,0.63,0,1,0,0.7060,0
0,0.29,0,1,0,0.4000,0
0,0.51,1,0,0,0.6270,1
0,0.24,0,0,1,0.3770,0
1,0.48,0,1,0,0.5750,1
1,0.18,1,0,0,0.2740,0
1,0.18,1,0,0,0.2030,2
1,0.33,0,1,0,0.3820,2
0,0.20,0,0,1,0.3480,0
1,0.29,0,0,1,0.3300,2
0,0.44,0,0,1,0.6300,0
0,0.65,0,0,1,0.8180,0
0,0.56,1,0,0,0.6370,2
0,0.52,0,0,1,0.5840,1
0,0.29,0,1,0,0.4860,0
0,0.47,0,1,0,0.5890,1
1,0.68,1,0,0,0.7260,2
1,0.31,0,0,1,0.3600,1
1,0.61,0,1,0,0.6250,2
1,0.19,0,1,0,0.2150,2
1,0.38,0,0,1,0.4300,1
0,0.26,1,0,0,0.4230,0
1,0.61,0,1,0,0.6740,0
1,0.40,1,0,0,0.4650,1
0,0.49,1,0,0,0.6520,1
1,0.56,1,0,0,0.6750,0
0,0.48,0,1,0,0.6600,1
1,0.52,1,0,0,0.5630,2
0,0.18,1,0,0,0.2980,0
0,0.56,0,0,1,0.5930,2
0,0.52,0,1,0,0.6440,1
0,0.18,0,1,0,0.2860,1
0,0.58,1,0,0,0.6620,2
0,0.39,0,1,0,0.5510,1
0,0.46,1,0,0,0.6290,1
0,0.40,0,1,0,0.4620,1
0,0.60,1,0,0,0.7270,2
1,0.36,0,1,0,0.4070,2
1,0.44,1,0,0,0.5230,1
1,0.28,1,0,0,0.3130,2
1,0.54,0,0,1,0.6260,0
```
Test data. Replace commas with tab characters or modify program.
```
0,0.51,1,0,0,0.6120,1
0,0.32,0,1,0,0.4610,1
1,0.55,1,0,0,0.6270,0
1,0.25,0,0,1,0.2620,2
1,0.33,0,0,1,0.3730,2
0,0.29,0,1,0,0.4620,0
1,0.65,1,0,0,0.7270,0
0,0.43,0,1,0,0.5140,1
0,0.54,0,1,0,0.6480,2
1,0.61,0,1,0,0.7270,0
1,0.52,0,1,0,0.6360,0
1,0.30,0,1,0,0.3350,2
1,0.29,1,0,0,0.3140,2
0,0.47,0,0,1,0.5940,1
1,0.39,0,1,0,0.4780,1
1,0.47,0,0,1,0.5200,1
0,0.49,1,0,0,0.5860,1
0,0.63,0,0,1,0.6740,2
0,0.30,1,0,0,0.3920,0
0,0.61,0,0,1,0.6960,2
0,0.47,0,0,1,0.5870,1
1,0.30,0,0,1,0.3450,2
0,0.51,0,0,1,0.5800,1
0,0.24,1,0,0,0.3880,1
0,0.49,1,0,0,0.6450,1
1,0.66,0,0,1,0.7450,0
0,0.65,1,0,0,0.7690,0
0,0.46,0,1,0,0.5800,0
0,0.45,0,0,1,0.5180,1
0,0.47,1,0,0,0.6360,0
0,0.29,1,0,0,0.4480,0
0,0.57,0,0,1,0.6930,2
0,0.20,1,0,0,0.2870,2
0,0.35,1,0,0,0.4340,1
0,0.61,0,0,1,0.6700,2
0,0.31,0,0,1,0.3730,1
1,0.18,1,0,0,0.2080,2
1,0.26,0,0,1,0.2920,2
0,0.28,1,0,0,0.3640,2
0,0.59,0,0,1,0.6940,2
```
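The training and test listings are comma-delimited, but the demo's np.loadtxt() calls expect tabs. Instead of converting the files, you can modify the program by changing the delimiter argument. A minimal sketch, reading two rows from an in-memory string in place of the real file:

```python
import io
import numpy as np

# Two rows copied from the comma-delimited training listing
text = """# people_train.txt
1,0.24,1,0,0,0.2950,2
0,0.39,0,0,1,0.5120,1
"""

# delimiter="," replaces the demo's delimiter="\t"
xy = np.loadtxt(io.StringIO(text), usecols=range(0, 7), delimiter=",",
                comments="#", dtype=np.float32)
train_x = xy[:, 0:6]
train_y = xy[:, 6].astype(int)
print(train_x.shape, train_y)   # (2, 6) [2 1]
```

To go the other way and keep the demo unchanged, text.replace(",", "\t") produces the tab-delimited form.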